import math

import torch
from torch import nn
from torch.nn import Parameter
import torch.nn.functional as F

from .encdec_multihead_attn_func import encdec_attn_func
from .fast_encdec_multihead_attn_func import fast_encdec_attn_func
from .fast_encdec_multihead_attn_norm_add_func import fast_encdec_attn_norm_add_func
from apex.normalization.fused_layer_norm import FusedLayerNorm

@torch.jit.script
def jit_dropout_add(x, residual, prob, is_training):
    # type: (Tensor, Tensor, float, bool) -> Tensor
    # Dropout on the attention output followed by the residual add, scripted
    # so the two elementwise ops can be fused by the JIT.
    out = F.dropout(x, p=prob, training=is_training)
    out = residual + out
    return out


class EncdecMultiheadAttn(nn.Module):
    """Multi-headed attention.

    See "Attention Is All You Need" for more details.
    """

    def __init__(self, embed_dim, num_heads, dropout=0.0, bias=False, include_norm_add=False, impl="fast"):
        super().__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.dropout = dropout
        self.head_dim = embed_dim // num_heads
        assert self.head_dim * num_heads == self.embed_dim, "embed_dim must be divisible by num_heads"
        self.bias = bias
        self.include_norm_add = include_norm_add
        self.impl = impl
        self.scaling = self.head_dim ** -0.5
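        # Softmax scale 1/sqrt(head_dim). It is passed explicitly to the
        # "default" implementation; the fused "fast" kernels apply their own
        # scaling internally.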

        self.in_proj_weight_q = Parameter(torch.empty(embed_dim, embed_dim))
        self.in_proj_weight_kv = Parameter(torch.empty(2 * embed_dim, embed_dim))
        self.out_proj_weight = Parameter(torch.empty(embed_dim, embed_dim))
        if self.bias:
            assert impl != "fast", "ERROR! The Fast implementation does not support biases!"
            self.in_proj_bias_q = Parameter(torch.empty(embed_dim))
            self.in_proj_bias_kv = Parameter(torch.empty(2 * embed_dim))
            self.out_proj_bias = Parameter(torch.empty(embed_dim))
        else:
            self.register_parameter("in_proj_bias_q", None)
            self.register_parameter("in_proj_bias_kv", None)
            self.register_parameter("out_proj_bias", None)
        if self.include_norm_add:
            if impl == "fast":
                self.lyr_nrm_gamma_weights = Parameter(torch.empty(embed_dim))
                self.lyr_nrm_beta_weights = Parameter(torch.empty(embed_dim))
                self.lyr_nrm = None
            else:
                self.register_parameter("lyr_nrm_gamma_weights", None)
                self.register_parameter("lyr_nrm_beta_weights", None)
                self.lyr_nrm = FusedLayerNorm(embed_dim)
        self.reset_parameters()

        # Select the attention kernel. The fused ("fast") path has a dedicated
        # variant that also folds layer norm and the residual add into the
        # kernel; the "default" path applies those separately in forward().
        if impl == "fast":
            if self.include_norm_add:
                self.attn_func = fast_encdec_attn_norm_add_func
            else:
                self.attn_func = fast_encdec_attn_func
        elif impl == "default":
            self.attn_func = encdec_attn_func
        else:
            raise ValueError("Unsupported impl: {} !".format(impl))

    def reset_parameters(self):
        nn.init.xavier_uniform_(self.in_proj_weight_q)
        # in_proj_weight_kv has shape [2 * hidden, hidden] but it should be
        # initialized like a [hidden, hidden] matrix.
        # sqrt(6 / (hidden + hidden)) / sqrt(6 / (2 * hidden + hidden)) = sqrt(1.5)
        # therefore xavier_uniform gain should be set to sqrt(1.5).
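        # For example, with hidden = 1024 the target bound is sqrt(6 / 2048),
        # while the unscaled bound for the [2048, 1024] matrix would be
        # sqrt(6 / 3072); their ratio is sqrt(3072 / 2048) = sqrt(1.5).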
        nn.init.xavier_uniform_(self.in_proj_weight_kv, gain=math.sqrt(1.5))
        nn.init.xavier_uniform_(self.out_proj_weight)
        if self.bias:
            nn.init.constant_(self.in_proj_bias_q, 0.0)
            nn.init.constant_(self.in_proj_bias_kv, 0.0)
            nn.init.constant_(self.out_proj_bias, 0.0)
        if self.include_norm_add:
            if self.impl == "fast":
                nn.init.ones_(self.lyr_nrm_gamma_weights)
                nn.init.zeros_(self.lyr_nrm_beta_weights)
            else:
                self.lyr_nrm.reset_parameters()

    def forward(self, query, key, value, key_padding_mask=None, need_weights=False, attn_mask=None, is_training=True):
        """Input shape: Time x Batch x Channel

        Self-attention can be implemented by passing in the same arguments for
        query, key and value. Future timesteps can be masked with the
        `mask_future_timesteps` argument. Padding elements can be excluded from
        the key by passing a binary ByteTensor (`key_padding_mask`) with shape:
        batch x src_len, where padding elements are indicated by 1s.
        """

        if key_padding_mask is not None:
            assert attn_mask is None, "ERROR: attn_mask and key_padding_mask should not both be defined!"
            mask = key_padding_mask
        elif attn_mask is not None:
            mask = attn_mask
        else:
            mask = None
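
        # A sketch of how a key padding mask can be built (illustrative;
        # `src_lengths` is a hypothetical [batch] tensor of valid lengths):
        #   key_padding_mask = torch.arange(src_len, device=src_lengths.device)[None, :] >= src_lengths[:, None]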

        if self.include_norm_add:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    key,
                    self.lyr_nrm_gamma_weights,
                    self.lyr_nrm_beta_weights,
                    self.in_proj_weight_q,
                    self.in_proj_weight_kv,
                    self.out_proj_weight,
                    mask,
                    self.dropout,
                )
            else:
                lyr_nrm_results = self.lyr_nrm(query)
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    lyr_nrm_results,
                    key,
                    self.in_proj_weight_q,
                    self.in_proj_weight_kv,
                    self.out_proj_weight,
                    self.in_proj_bias_q,
                    self.in_proj_bias_kv,
                    self.out_proj_bias,
                    mask,
                    self.dropout,
                )
                if is_training:
                    outputs = jit_dropout_add(outputs, query, self.dropout, is_training)
                else:
                    outputs = outputs + query
        else:
            if self.impl == "fast":
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    query,
                    key,
                    self.in_proj_weight_q,
                    self.in_proj_weight_kv,
                    self.out_proj_weight,
                    mask,
                    self.dropout,
                )
            else:
                outputs = self.attn_func(
                    attn_mask is not None,
                    is_training,
                    self.num_heads,
                    self.scaling,
                    query,
                    key,
                    self.in_proj_weight_q,
                    self.in_proj_weight_kv,
                    self.out_proj_weight,
                    self.in_proj_bias_q,
                    self.in_proj_bias_kv,
                    self.out_proj_bias,
                    mask,
                    self.dropout,
                )

        # Attention weights are never computed by these kernels; `need_weights`
        # is accepted only for interface compatibility.
        return outputs, None
